/**
 * 
 */
package edu.umd.clip.lm.tools;

import java.util.*;

import org.w3c.dom.*;
import javax.xml.parsers.*;
import  org.w3c.dom.bootstrap.DOMImplementationRegistry;
import  org.w3c.dom.Document;
import  org.w3c.dom.ls.*;

import java.io.*;
import java.nio.charset.Charset;

import edu.umd.clip.lm.factors.*;
import edu.umd.clip.lm.util.tree.*;
import edu.umd.clip.lm.model.*;
import edu.umd.clip.lm.util.*;
import edu.umd.clip.lm.util.algo.*;
import edu.umd.clip.jobs.*;

import edu.berkeley.nlp.util.*;
/**
 * Command-line tool that clusters the vocabulary of the experiment's main
 * factor with {@link MDI} and stores the resulting binary word tree.
 * <p>
 * The tool reads training sentences (one per line, UTF-8), counts adjacent
 * word pairs, hands the resulting left/right context distributions to the
 * clustering algorithm, and records every split reported by the algorithm as
 * a node of a tree. The tree is then written to <code>word-tree.dot</code>,
 * installed into the {@link Experiment} (whose configuration is saved back to
 * the <code>-config</code> path) and written to <code>word-tree.xml</code>.
 *
 * @author Denis Filimonov <den@cs.umd.edu>
 *
 */
public class WordClustererMDI {

	/** Command-line options understood by {@link WordClustererMDI#main(String[])}. */
	public static class Options {
		@Option(name = "-input", required = false, usage = "Training data file (Default: stdin)")
		public String input;
		@Option(name = "-config", usage = "XML config file")
		public String config;
		@Option(name = "-jobs", usage = "number of concurrent jobs (default: 1)")
		public int jobs = 1;
	}

	/**
	 * Entry point: collects adjacent-word statistics from the training data,
	 * runs the clustering and writes the resulting word tree.
	 * <p>
	 * If the training data cannot be read the stack trace is printed and the
	 * method returns without clustering, so that the configuration file is not
	 * overwritten with a tree built from missing or truncated data.
	 *
	 * @param args command-line arguments, see {@link Options}
	 */
	public static void main(String[] args) {
		OptionParser optParser = new OptionParser(Options.class);
		Options opts = (Options) optParser.parse(args, true);

		JobManager.initialize(opts.jobs);
		Thread thread = new Thread(JobManager.getInstance(), "Job Manager");
		thread.setDaemon(true);
		thread.start();

		Experiment.initialize(opts.config);
		Experiment experiment = Experiment.getInstance();

		FactorTupleDescription tupleDesc = experiment.getTupleDescription();
		byte factorIndex = tupleDesc.getMainFactorIndex();
		edu.umd.clip.lm.factors.Dictionary dictionary = tupleDesc.getDictionary(factorIndex);

		// now collect the statistics
		// TODO: LM order is assumed to be 3
		FLMInputParser parser = new FLMInputParser(tupleDesc);
		long[] startTuples = new long[1];
		startTuples[0] = tupleDesc.createStartTuple();
		parser.setStartTuples(startTuples);
		long[] endTuples = new long[1];
		endTuples[0] = tupleDesc.createEndTuple();
		parser.setEndTuples(endTuples);

		// leftCounts: word -> (preceding word -> count)
		// rightCounts: word -> (following word -> count)
		HashMap<String, HashMap<String, MutableDouble>> leftCounts;
		HashMap<String, HashMap<String, MutableDouble>> rightCounts;
		leftCounts = new HashMap<String, HashMap<String, MutableDouble>>(dictionary.size());
		rightCounts = new HashMap<String, HashMap<String, MutableDouble>>(dictionary.size());

		double totalCounts = 0.0;
		try {
			// Only a file we opened ourselves is closed afterwards; System.in is left alone.
			FileInputStream fileInput = opts.input == null ? null : new FileInputStream(opts.input);
			try {
				InputStreamReader reader = new InputStreamReader(
						fileInput == null ? System.in : fileInput, Charset.forName("UTF-8"));
				BufferedReader inputData = new BufferedReader(reader);

				for (String line = inputData.readLine(); line != null; line = inputData.readLine()) {
					long[] sentence = parser.parseSentence(line);

					for (int i = 0; i < sentence.length - 1; ++i) {
						String word1 = dictionary.getWord(FactorTuple.getValue(sentence[i], factorIndex));
						String word2 = dictionary.getWord(FactorTuple.getValue(sentence[i + 1], factorIndex));

						addCount(leftCounts, word2, word1);
						addCount(rightCounts, word1, word2);
						totalCounts += 1.0;
					}
				}
			} finally {
				if (fileInput != null) {
					fileInput.close();
				}
			}
		} catch (IOException e) {
			// Do not cluster (and overwrite the config) based on unreadable or truncated data.
			e.printStackTrace();
			return;
		}

		// Turn pair counts into relative frequencies over all observed pairs.
		scaleCounts(leftCounts, totalCounts);
		scaleCounts(rightCounts, totalCounts);

		// Maps every cluster reported by the algorithm to its node in the tree being built.
		final HashMap<Collection<String>, AnyaryTree<String>> treeNodes = new HashMap<Collection<String>, AnyaryTree<String>>();
		MDIClusterNotifier<String> notifier = new MDIClusterNotifier<String>() {

			/* (non-Javadoc)
			 * @see edu.umd.clip.lm.util.algo.MDIClusterNotifier#notify(java.util.Collection, java.util.Collection, java.util.Collection)
			 */
			public boolean notify(Collection<String> oldCluster,
					Collection<String> cluster1, Collection<String> cluster2) {
				AnyaryTree<String> node = treeNodes.get(oldCluster);
				if (node == null) {
					throw new IllegalStateException("No tree node is registered for the cluster being split");
				}
				AnyaryTree<String> subnode1 = addClusterNode(node, cluster1);
				AnyaryTree<String> subnode2 = addClusterNode(node, cluster2);

				treeNodes.put(cluster1, subnode1);
				treeNodes.put(cluster2, subnode2);
				return true;
			}
		};

		HashSet<String> vocab = new HashSet<String>(100);
		for (DictionaryIterator it = dictionary.iterator(true); it.hasNext(); ) {
			vocab.add(dictionary.getWord(it.next()));
		}

		// The whole vocabulary is the initial cluster and corresponds to the root of the tree.
		AnyaryTree<String> theRoot = new AnyaryTree<String>(null);
		treeNodes.put(vocab, theRoot);

		MDI<String,String> algo = new MDI<String,String>(vocab, vocab);
		algo.setNotifier(notifier);

		for (Map.Entry<String, HashMap<String, MutableDouble>> entry : leftCounts.entrySet()) {
			String word = entry.getKey();
			for (Map.Entry<String, MutableDouble> entry2 : entry.getValue().entrySet()) {
				algo.setLeftProb(word, entry2.getKey(), entry2.getValue().doubleValue());
			}
		}
		// drop the reference so the counts can be collected before partitioning
		leftCounts = null;

		for (Map.Entry<String, HashMap<String, MutableDouble>> entry : rightCounts.entrySet()) {
			String word = entry.getKey();
			for (Map.Entry<String, MutableDouble> entry2 : entry.getValue().entrySet()) {
				algo.setRightProb(word, entry2.getKey(), entry2.getValue().doubleValue());
			}
		}
		rightCounts = null;

		algo.normalizeDistributions();
		algo.partition(vocab);

		try {
			OutputStreamWriter treeOutput = openUtf8Writer("word-tree.dot");
			try {
				GraphvizOutput graph = new GraphvizOutput("HFT");
				treeOutput.write(graph.getOutput(theRoot));
			} finally {
				treeOutput.close();
			}
		} catch (IOException e) {
			e.printStackTrace();
		}

		WordTree<WordTreePayload> wordTree = new WordTree<WordTreePayload>(makeWordTree(new BinaryPrefix(new BitSet(0), 0), theRoot));
		experiment.setWordTree(wordTree);
		try {
			experiment.saveConfig(opts.config);
		} catch (IOException e) {
			e.printStackTrace();
		}

		DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
		DocumentBuilder db;
		try {
			db = dbf.newDocumentBuilder();
		} catch (ParserConfigurationException pce) {
			pce.printStackTrace();
			return;
		}

		Document hftDoc = db.getDOMImplementation().createDocument(null, null, null);
		Element elem = wordTree.createXML(hftDoc);
		hftDoc.appendChild(elem);
		try {
			// The writer's charset matches the encoding declared on the LSOutput below.
			OutputStreamWriter hftOutput = openUtf8Writer("word-tree.xml");
			try {
				DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance();

				DOMImplementationLS impl =
				    (DOMImplementationLS) registry.getDOMImplementation("LS");
				LSSerializer writer = impl.createLSSerializer();
				writer.getDomConfig().setParameter("format-pretty-print", Boolean.TRUE);
				LSOutput output = impl.createLSOutput();
				output.setEncoding("UTF-8");
				output.setCharacterStream(hftOutput);
				writer.write(hftDoc.getDocumentElement(), output);
			} finally {
				hftOutput.close();
			}
		} catch (Exception e) {
			e.printStackTrace();
		}

	}

	/**
	 * Adds one to <code>counts[outerKey][innerKey]</code>, creating the inner
	 * map and the counter on first use.
	 * <p>
	 * Inner maps start at the default capacity: sizing them by the number of
	 * words seen so far would reserve a vocabulary-sized table for every word.
	 */
	private static void addCount(HashMap<String, HashMap<String, MutableDouble>> counts,
			String outerKey, String innerKey) {
		HashMap<String, MutableDouble> inner = counts.get(outerKey);
		if (inner == null) {
			inner = new HashMap<String, MutableDouble>();
			counts.put(outerKey, inner);
		}
		MutableDouble count = inner.get(innerKey);
		if (count == null) {
			inner.put(innerKey, new MutableDouble(1.0));
		} else {
			count.set(count.doubleValue() + 1.0);
		}
	}

	/** Divides every stored count by <code>total</code>, in place. */
	private static void scaleCounts(HashMap<String, HashMap<String, MutableDouble>> counts, double total) {
		for (HashMap<String, MutableDouble> inner : counts.values()) {
			for (MutableDouble count : inner.values()) {
				count.set(count.doubleValue() / total);
			}
		}
	}

	/**
	 * Creates a tree node for <code>cluster</code> and attaches it to
	 * <code>parent</code>. The node is labelled with the word itself when the
	 * cluster holds exactly one word and has a <code>null</code> label otherwise.
	 *
	 * @return the newly attached child node
	 */
	private static AnyaryTree<String> addClusterNode(AnyaryTree<String> parent, Collection<String> cluster) {
		String label = null;
		if (cluster.size() == 1) {
			label = cluster.iterator().next();
		}
		AnyaryTree<String> child = new AnyaryTree<String>(label);
		parent.addChild(child);
		return child;
	}

	/** Opens <code>fileName</code> for writing text encoded as UTF-8 (replacing any existing file). */
	private static OutputStreamWriter openUtf8Writer(String fileName) throws IOException {
		return new OutputStreamWriter(new FileOutputStream(fileName), Charset.forName("UTF-8"));
	}

	/**
	 * Recursively converts the cluster tree rooted at <code>node</code> into a
	 * binary tree whose payloads carry the node's label and its binary prefix.
	 * <p>
	 * A node with children is expected to have exactly two of them. The
	 * children come from a {@link HashSet}, so which of the two becomes the
	 * left (prefix extended by zero) and which the right (prefix extended by
	 * one) subtree follows the set's iteration order. A node whose child set
	 * is <code>null</code> or empty becomes a leaf.
	 *
	 * @param prefix binary prefix assigned to <code>node</code>
	 * @param node   root of the cluster subtree to convert
	 * @return the converted subtree
	 */
	private static BinaryTree<WordTreePayload> makeWordTree(BinaryPrefix prefix, AnyaryTree<String> node) {
		String word = node.getPayload();
		BinaryTree<WordTreePayload> tree = new BinaryTree<WordTreePayload>(new WordTreePayload(prefix, word));

		HashSet<AnyaryTree<String>> children = node.getChildren();
		if (children != null && !children.isEmpty()) {
			assert(children.size() == 2);
			Iterator<AnyaryTree<String>> iterator = children.iterator();
			AnyaryTree<String> left = iterator.next();
			AnyaryTree<String> right = iterator.next();

			tree.attachLeft(makeWordTree(prefix.appendZero(), left));
			tree.attachRight(makeWordTree(prefix.appendOne(), right));
		}
		return tree;
	}
}
